import os
os.chdir("../")
import sys
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
import numpy as np
import matplotlib.pyplot as plt

# Plot palette as hex RGB strings. color_list fixes the order used when
# curves pick a colour by index below (e.g. color_list[-1], color_list[2]).
CB91_Blue, CB91_Green, CB91_Pink = '#2CBDFE', '#47DBCD', '#F3A0F2'
CB91_Purple, CB91_Violet, CB91_Amber = '#9D2EC5', '#661D98', '#F5B14C'
seventh_color = "#4b97ec"
color_list = [
    CB91_Blue,
    CB91_Pink,
    CB91_Green,
    CB91_Amber,
    CB91_Purple,
    CB91_Violet,
    seventh_color,
]


# Directories holding the saved .npy result arrays. These are relative to the
# current working directory, which the script moves one level up at start-up.
RESULT_PATH = "results/results_ecs_long/results_ecs_long/"
RESULT_PATH_LONGER = "results/results_ecs_long/results_ecs_longer_dis_hard_discrete/"
RESULT_PATH_40k = "results/results_ecs_long/results_ecs_40k_dis_hard_discrete/"

NUM_RUNS = 4          # run count used as the standard-error denominator
NUM_EPISODES = 80000

# x-axis positions for the plotted curves: every 20th episode in
# [0, NUM_EPISODES). Presumably this matches the logging interval of the
# saved arrays -- TODO confirm against the training code.
episodes_indices = list(range(0, NUM_EPISODES, 20))

def _load_reward_stats(path):
    """Load a saved running-reward array and summarise it across runs.

    Axis 0 of the loaded array is treated as the run axis (it is the axis
    that is averaged over). The overall mean reward is printed as a quick
    sanity check.

    Args:
        path: Path of the ``.npy`` file to load.

    Returns:
        A tuple ``(rewards, mean, sem)``: the raw array, its mean over
        axis 0, and the standard error of that mean,
        ``std(axis=0) / sqrt(number_of_runs)``.
    """
    rewards = np.load(path)
    mean = np.mean(rewards, axis=0)
    # Use the run count actually present in the file rather than the global
    # NUM_RUNS, so the error band stays correct if a result file holds a
    # different number of runs.
    n_runs = rewards.shape[0]
    # NOTE(review): np.std defaults to ddof=0 (population std), which slightly
    # understates the sample SEM for few runs; kept to match existing figures.
    sem = np.std(rewards, axis=0) / np.sqrt(n_runs)
    print(np.mean(mean))
    return rewards, mean, sem


iql_reward, iql_reward_mean, iql_std = _load_reward_stats(
    RESULT_PATH + "iql_running_reward_argmax.npy")

obl_reward, obl_reward_mean, obl_reward_std = _load_reward_stats(
    RESULT_PATH + "obl_running_reward_argmax.npy")

# OBL + MI + OBL-utility variant, currently disabled:
# obl_mi_util_obl_reward, obl_mi_util_obl_reward_mean, _ = _load_reward_stats(
#     RESULT_PATH + "obl_running_reward_mi_log2_mi_loss_argmax.npy")

(obl_mi_util_iql_reward,
 obl_mi_util_iql_reward_mean,
 obl_mi_util_iql_reward_std) = _load_reward_stats(
    RESULT_PATH_LONGER + "obl_running_reward_mi_log2_mi_loss_util_iql_argmax.npy")

(obl_mi_util_dial_iql_reward,
 obl_mi_util_dial_iql_reward_mean,
 obl_mi_util_dial_iql_reward_std) = _load_reward_stats(
    RESULT_PATH_LONGER + "obl_dial_running_reward_mi_log2_mi_loss_argmax.npy")

def _plot_mean_with_sem(x, mean, sem, label, color):
    """Plot a mean curve plus a shaded +/- one-standard-error band.

    Args:
        x: x-axis positions (one per point of the curve).
        mean: Mean values; singleton axes are squeezed away.
        sem: Standard-error values matching ``mean``; singleton axes are
            squeezed away as well.
        label: Legend label for the curve.
        color: Colour used for both the line and the shaded band.
    """
    # Squeeze both arrays. Squeezing only the mean (as the original code did)
    # lets a leftover singleton axis in ``sem`` broadcast the band bounds
    # into a 2-D array, which fill_between cannot draw.
    mean = np.squeeze(mean)
    sem = np.squeeze(sem)
    plt.plot(x, mean, label=label, color=color)
    plt.fill_between(x, mean - sem, mean + sem, facecolor=color, alpha=0.3)


# Plot Mean
_plot_mean_with_sem(episodes_indices, iql_reward_mean, iql_std,
                    "IQL", color_list[-1])
_plot_mean_with_sem(episodes_indices, obl_reward_mean, obl_reward_std,
                    "OBL", color_list[-3])
# plt.plot(episodes_indices, obl_mi_util_obl_reward_mean.squeeze(), label = "OBL + MI + OBL Util")
_plot_mean_with_sem(episodes_indices, obl_mi_util_iql_reward_mean,
                    obl_mi_util_iql_reward_std,
                    "OBL + MI + IQL Util", color_list[1])
_plot_mean_with_sem(episodes_indices, obl_mi_util_dial_iql_reward_mean,
                    obl_mi_util_dial_iql_reward_std,
                    "OBL + MI + DIAL Util", color_list[2])

# Hide the top and right frame lines for a cleaner figure.
ax = plt.gca()
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)

plt.legend()
plt.ylabel("Task Running Reward")
plt.xlabel("Episodes")
# Vertical marker spanning the full axis height at episode 8000.
# NOTE(review): the meaning of this episode is not stated in this file --
# confirm and consider giving it a legend label.
plt.axvline(8000, 0, 1, color=color_list[3])
plt.show()

# iql_reward_ir = np.load(RESULT_PATH + "iql_running_reward_ir_argmax.npy")
# iql_reward_ir_mean = np.mean(iql_reward_ir, axis = 0)
# print(np.mean(iql_reward_ir_mean))
#
# iql_reward = np.load(RESULT_PATH + "iql_running_reward_argmax.npy")
# iql_reward_mean = np.mean(iql_reward, axis = 0)
# print(np.mean(iql_reward_mean))
#
# obl_reward_mi = np.load(RESULT_PATH + "obl_running_reward_mi_log2_argmax.npy")
# obl_reward_mi_mean = np.mean(obl_reward_mi, axis = 0)
# print(np.mean(obl_reward_mi_mean))
#
# obl_reward = np.load(RESULT_PATH + "obl_running_reward_argmax.npy")
# obl_reward_mean = np.mean(obl_reward, axis = 0)
# print(np.mean(obl_reward_mean))
#
# # Plot Mean
# plt.plot(episodes_indices, obl_reward_mi_mean.squeeze(), label = "OBL with MI reward")
# plt.plot(episodes_indices, obl_reward_mean.squeeze(), label = "OBL")
# plt.plot(episodes_indices, iql_reward_ir_mean.squeeze(), label = "IQL with intermediate reward")
# plt.plot(episodes_indices, iql_reward_mean.squeeze(), label = "IQL")
# plt.legend()
# plt.ylabel("Task Reward")
# plt.xlabel("Episodes")
# plt.show()
